import numpy as np
import paddle
import random


def set_seed(seed=1000):
    """Seed Python's, NumPy's and Paddle's global random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``paddle.seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)

# 创建数据集加载器
def create_dataloader(dataset,
                      trans_fn=None,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None):
    """Build a ``paddle.io.DataLoader`` over ``dataset``.

    Args:
        dataset: Dataset to load from. When ``trans_fn`` is given it must
            expose a ``map`` method (presumably a PaddleNLP-style
            ``MapDataset`` -- TODO confirm against callers).
        trans_fn: Optional per-example transform applied via
            ``dataset.map(trans_fn)`` before batching.
        mode: ``'train'`` selects a shuffling ``DistributedBatchSampler``;
            any other value selects an unshuffled ``BatchSampler``.
        batch_size: Number of examples per batch.
        batchify_fn: Optional collate function that merges a list of
            examples into one batch; forwarded as ``collate_fn``.

    Returns:
        The configured ``paddle.io.DataLoader``.
    """
    if trans_fn is not None:
        dataset = dataset.map(trans_fn)

    if mode == 'train':
        # Training data is shuffled. Per the Paddle docs,
        # DistributedBatchSampler splits batches across trainers in
        # multi-device runs.
        sampler = paddle.io.DistributedBatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=True)
    else:
        # Evaluation/test data keeps its original order.
        sampler = paddle.io.BatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=False)

    return paddle.io.DataLoader(
        dataset, batch_sampler=sampler, collate_fn=batchify_fn)